package com.embeddings.service;


import co.elastic.clients.elasticsearch.ElasticsearchClient;
import co.elastic.clients.elasticsearch._types.Script;
import co.elastic.clients.elasticsearch._types.query_dsl.*;
import co.elastic.clients.elasticsearch.core.IndexResponse;
import co.elastic.clients.elasticsearch.core.SearchResponse;
import co.elastic.clients.elasticsearch.indices.CreateIndexRequest;
import co.elastic.clients.elasticsearch.indices.CreateIndexResponse;
import co.elastic.clients.json.JsonData;
import com.embeddings.bean.SearchResult;
import com.embeddings.bean.Sell;
import com.embeddings.embed.EmbedClient;
import jakarta.annotation.Resource;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;

import java.io.IOException;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

/**
 * Service that manages the {@value #INDEX_NAME} Elasticsearch index: creating it, indexing
 * {@link Sell} documents together with an embedding of their remark, and running cosine-similarity
 * searches against those embeddings.
 */
@Service
public class EsDocumentService {

    @Value("${embedding.uri}")
    private String embeddingUri;

    @Value("${embedding.api-key}")
    private String embeddingApiKey;

    @Resource
    private ElasticsearchClient client;

    /** Name of the Elasticsearch index this service creates, writes to and searches. */
    public static final String INDEX_NAME = "sell_service";

    /** Minimum (clamped) cosine similarity a document must reach to be returned by a search. */
    public static final float SIMILARITY_THRESHOLD = 0.2f;

    /**
     * Maximum number of hits returned by {@link #searchWithGivenVector(double[])}. This equals the
     * Elasticsearch default {@code size}, which the single-argument search previously relied on implicitly.
     */
    private static final int DEFAULT_SEARCH_SIZE = 10;

    /**
     * Creates the {@value #INDEX_NAME} index with its mapping:
     * <ul>
     *   <li>{@code remark_vec}: a 1024-dimension {@code dense_vector}, indexed, cosine similarity;</li>
     *   <li>{@code remark}: a {@code text} field analyzed with {@code ik_smart}.</li>
     * </ul>
     * NOTE(review): {@code ik_smart} presumably requires the IK analysis plugin on the cluster.
     *
     * @throws IOException if the request to Elasticsearch fails
     */
    public void createIndex() throws IOException {
        CreateIndexRequest request = new CreateIndexRequest.Builder()
                .index(INDEX_NAME)
                .mappings(m -> m
                        .properties("remark_vec", p -> p
                                .denseVector(dv -> dv
                                        .dims(1024)
                                        .index(true)
                                        .similarity("cosine")
                                )
                        )
                        .properties("remark", p -> p
                                .text(t -> t
                                        .analyzer("ik_smart") // IK analyzer (Chinese word segmentation)
                                )
                        )
                )
                .build();

        CreateIndexResponse createIndexResponse = client.indices().create(request);
        System.out.println("Index created: " + createIndexResponse.acknowledged());
    }

    /**
     * Embeds each sell's remark and indexes the sell into {@value #INDEX_NAME}, one request per document.
     * <p>
     * Side effect: each {@link Sell} in {@code sellList} is mutated in place. Its {@code remark_vec}
     * is set to the embedding returned by {@link EmbedClient} before it is indexed.
     *
     * @param sellList sells to index; must not be {@code null}
     * @throws IOException if an index request to Elasticsearch fails
     */
    public void indexSellList(List<Sell> sellList) throws IOException {
        Objects.requireNonNull(sellList, "sellList");
        for (Sell sell : sellList) {
            sell.setRemark_vec(EmbedClient.getEmbedding(embeddingUri, embeddingApiKey, sell.getRemark()));
            IndexResponse response = client.index(i -> i
                    .index(INDEX_NAME)
                    .id(sell.getId())
                    .document(sell)
            );
            System.out.println("Sell indexed: " + response.id());
        }
    }

    /**
     * Searches for sells whose {@code remark_vec} is similar to {@code queryVector}, returning at most
     * {@value #DEFAULT_SEARCH_SIZE} results.
     *
     * @param queryVector query embedding; must not be {@code null}. It presumably must have the same
     *                    dimension as the mapping (1024) — TODO confirm against the embedding model.
     * @return results with similarity {@code >= SIMILARITY_THRESHOLD}, sorted by score descending
     * @throws IOException if the search request to Elasticsearch fails
     * @see #searchWithGivenVector(double[], int)
     */
    public List<SearchResult> searchWithGivenVector(double[] queryVector) throws IOException {
        return searchWithGivenVector(queryVector, DEFAULT_SEARCH_SIZE);
    }

    /**
     * Searches for sells whose {@code remark_vec} is similar to {@code queryVector}.
     * <p>
     * The score is the cosine similarity between the query vector and {@code remark_vec}, clamped
     * to {@code [0, 1]}. Documents scoring below {@link #SIMILARITY_THRESHOLD} are given score 0
     * and excluded via {@code min_score}.
     *
     * @param queryVector query embedding; must not be {@code null}
     * @param size        maximum number of hits to request from Elasticsearch; must be {@code >= 0}
     * @return results with similarity {@code >= SIMILARITY_THRESHOLD}, sorted by score descending
     * @throws IllegalArgumentException if {@code size} is negative
     * @throws IOException              if the search request to Elasticsearch fails
     */
    public List<SearchResult> searchWithGivenVector(double[] queryVector, int size) throws IOException {
        Objects.requireNonNull(queryVector, "queryVector");
        if (size < 0) {
            throw new IllegalArgumentException("size must be >= 0, got " + size);
        }

        // Vector-similarity query: cosine similarity clamped to [0, 1], forced to 0 below the threshold.
        ScriptScoreQuery scriptScoreQuery = ScriptScoreQuery.of(q -> q
                .query(QueryBuilders.matchAll().build()._toQuery())
                .script(Script.of(s -> s.inline(i -> i
                        .source("double score = cosineSimilarity(params.query_vector, 'remark_vec'); " +
                                "score = Math.min(1.0, Math.max(0.0, score)); " + // keep the score within [0, 1]
                                "if (score < params.threshold) { return 0; } else { return score; }")
                        .params(Map.of(
                                "query_vector", JsonData.of(queryVector),
                                "threshold", JsonData.of(SIMILARITY_THRESHOLD) // threshold passed as a script parameter
                        ))))));

        // Bool query with the vector-similarity query as its only should clause.
        Query boolQuery = QueryBuilders.bool(b -> b
                .should(scriptScoreQuery._toQuery())
        );

        // No score functions are configured. Presumably this wrapper exists to apply min_score,
        // which drops documents the script scored 0 (i.e. below the threshold).
        Query functionScoreQuery = QueryBuilders.functionScore(fs -> fs
                .query(boolQuery)
                .scoreMode(FunctionScoreMode.Max)
                .boostMode(FunctionBoostMode.Replace)
                .minScore((double) SIMILARITY_THRESHOLD)
        );

        SearchResponse<Sell> combinedSearchResponse = client.search(s -> s
                        .index(INDEX_NAME)
                        .query(functionScoreQuery)
                        .size(size),
                Sell.class);

        return combinedSearchResponse.hits().hits().stream()
                // Compare in float precision: Lucene scores are floats. Widening the float threshold to
                // double (0.20000000298...) would drop hits that scored exactly at the threshold and
                // already passed min_score on the server.
                .filter(hit -> hit.score() != null && hit.score().floatValue() >= SIMILARITY_THRESHOLD)
                .map(hit -> new SearchResult(hit.source(), hit.score().doubleValue()))
                .sorted(Comparator.comparingDouble(SearchResult::getScore).reversed())
                .collect(Collectors.toList());
    }
}